import torch
import torch.nn as nn

# Toy dimensions for a single-layer RNN cell demo.
seq_len = 3  # number of time steps fed through the cell
batch_size = 1
input_size = 4  # features per time step
hidden_size = 2  # size of the recurrent hidden state

# One RNN cell (PyTorch's default nonlinearity is tanh); the time loop over
# the sequence is written out explicitly below instead of using nn.RNN.
cell = nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

# Random input data laid out as (seq_len, batch_size, input_size), so that
# iterating over the first dimension yields one time step at a time.
dataset = torch.randn(seq_len, batch_size, input_size)

# Initial hidden state h_0, shaped (batch_size, hidden_size) as RNNCell expects.
hidden = torch.zeros(batch_size, hidden_size)

# Unroll the recurrence manually: each step feeds x_t, shaped
# (batch_size, input_size), plus the previous hidden state into the cell to
# obtain the next hidden state. The loop variable is deliberately not named
# `input`, which would shadow the builtin at module level.
for idx, x_t in enumerate(dataset):
    print(x_t.shape)
    hidden = cell(x_t, hidden)
    print(hidden.shape)
    print(f"第 {idx + 1} 步的隐藏状态: {hidden}")
